/*
 * Written by Dawid Kurzyniec and released to the public domain, as explained
 * at http://creativecommons.org/licenses/publicdomain
 */

package edu.emory.mathcs.util.concurrent;

import java.util.*;
import edu.emory.mathcs.util.collections.WeakIdentityHashMap;

/**
 * A variant of {@link InheritableThreadLocal} whose values can be propagated
 * to worker threads by means of {@link ThreadContext}.
 * <p>
 * Values of all delegatable thread locals are not kept in the superclass
 * storage; instead, each thread has a single identity map (keyed by the
 * {@code DelegatableThreadLocal} instance) holding all of its values. This
 * allows the complete set of a thread's values to be captured
 * ({@code takeSnapshot}), installed in another thread ({@code delegate}),
 * and later withdrawn again ({@code restore}). Subclasses can customize each
 * step by overriding {@link #childValue}, {@link #snapshotValue},
 * {@link #delegatedValue} and {@link #onUndelegated}.
 *
 * @see ThreadContext
 *
 * @author Dawid Kurzyniec
 * @version 1.0
 */
public class DelegatableThreadLocal extends InheritableThreadLocal {

    /**
     * Holder of the per-thread map from DelegatableThreadLocal instances to
     * their values in that thread. The holder itself never changes; only the
     * map stored in it for a given thread is replaced (by delegate/restore).
     */
    private static final ThreadLocalMap threadLocalMap = new ThreadLocalMap();

    /**
     * Creates a new instance of DelegatableThreadLocal
     */
    public DelegatableThreadLocal() {}

    /**
     * Returns the calling thread's value of this thread local. If the thread
     * has no entry for it yet, {@link #initialValue()} is invoked and its
     * result is stored and returned. A value explicitly set to {@code null}
     * is returned as {@code null} without re-invoking {@code initialValue()}.
     *
     * @return the calling thread's value of this thread local
     */
    public Object get() {
        Map tlmap = (Map)threadLocalMap.get();
        Object value = tlmap.get(this);
        if (value != null) return value;
        // distinguish "explicitly set to null" from "no entry yet"
        if (tlmap.containsKey(this)) return null;
        value = initialValue();
        tlmap.put(this, value);
        return value;
    }

    /**
     * Sets the calling thread's value of this thread local.
     *
     * @param value the new value (may be {@code null})
     */
    public void set(Object value) {
        Map tlmap = (Map)threadLocalMap.get();
        tlmap.put(this, value);
    }

    /**
     * Removes the calling thread's value of this thread local, so that the
     * next {@link #get()} in this thread invokes {@link #initialValue()}
     * again.
     * <p>
     * This override is required because values are kept in the private
     * per-thread map rather than in the superclass storage; without it,
     * {@code ThreadLocal.remove()} would leave the value visible through
     * {@link #get()} untouched.
     */
    public void remove() {
        Map tlmap = (Map)threadLocalMap.get();
        tlmap.remove(this);
    }

    /**
     * Returns the value used by {@link #get()} when the calling thread has
     * no entry for this thread local. The default implementation returns
     * {@code null}.
     *
     * @return the initial value for the calling thread
     */
    protected Object initialValue() {
        return null;
    }

    /**
     * Computes the child thread's value from the parent's value when a new
     * thread is created; invoked for every value present in the parent's
     * map. The default implementation returns the parent value unchanged.
     *
     * @param parentValue the parent thread's value
     * @return the value to install in the child thread
     */
    protected Object childValue(Object parentValue) {
        return parentValue;
    }

    /**
     * Computes the value stored in a snapshot taken by {@code takeSnapshot}.
     * The default implementation returns the original value unchanged.
     *
     * @param origValue the calling thread's current value
     * @return the value to record in the snapshot
     */
    protected Object snapshotValue(Object origValue) {
        return origValue;
    }

    /**
     * Computes the value installed by {@code delegate} from a snapshot value.
     * The default implementation returns the snapshot value unchanged.
     *
     * @param snapshotValue the value recorded in the snapshot
     * @return the value to install in the thread calling {@code delegate}
     */
    protected Object delegatedValue(Object snapshotValue) {
        return snapshotValue;
    }

    /**
     * Invoked by {@code restore} for each value present in the calling
     * thread's map just before the original map is reinstated. The default
     * implementation does nothing.
     *
     * @param undelegatedValue the value being withdrawn
     */
    protected void onUndelegated(Object undelegatedValue) {}

    /**
     * Inheritable holder of the per-thread value maps.
     */
    private static final class ThreadLocalMap extends InheritableThreadLocal {
        protected Object initialValue() {
            return new WeakIdentityHashMap();
        }
        protected Object childValue(Object parentValue) {
            // make sure that the child thread gets its own map so that
            // it does not mess up parent's map by adding/changing values
            Map oldmap = (Map)parentValue;
            Map newmap = new WeakIdentityHashMap(oldmap.size());
            for (Iterator itr=oldmap.entrySet().iterator(); itr.hasNext();) {
                Map.Entry entry = (Map.Entry)itr.next();
                DelegatableThreadLocal dtl = (DelegatableThreadLocal)entry.getKey();
                newmap.put(dtl, dtl.childValue(entry.getValue()));
            }
            return newmap;
        }
    }

    /**
     * Used by ThreadContext. Captures the calling thread's values of all
     * delegatable thread locals that have an entry in this thread, each
     * passed through {@link #snapshotValue}. The thread's own map is not
     * modified.
     *
     * @return a new map from thread local to snapshot value
     */
    static Map takeSnapshot() {
        Map map = (Map)threadLocalMap.get();
        Map snapshot = new WeakIdentityHashMap(map.size());
        for (Iterator itr=map.entrySet().iterator(); itr.hasNext();) {
            Map.Entry entry = (Map.Entry)itr.next();
            DelegatableThreadLocal dtl = (DelegatableThreadLocal)entry.getKey();
            snapshot.put(dtl, dtl.snapshotValue(entry.getValue()));
        }
        return snapshot;
    }

    /**
     * Used by ThreadContext. Replaces the calling thread's map with one built
     * from {@code snapshot}, each value passed through
     * {@link #delegatedValue}. Thread locals absent from the snapshot have no
     * entry while delegated, so their {@code get()} falls back to
     * {@code initialValue()}.
     *
     * @param snapshot a map previously returned by {@code takeSnapshot}
     * @return the map that was in effect before; pass it to {@code restore}
     */
    static Map delegate(Map snapshot) {
        Map oldmap = (Map)threadLocalMap.get();
        Map newmap = new WeakIdentityHashMap(snapshot.size());
        for (Iterator itr=snapshot.entrySet().iterator(); itr.hasNext();) {
            Map.Entry entry = (Map.Entry)itr.next();
            DelegatableThreadLocal dtl = (DelegatableThreadLocal)entry.getKey();
            newmap.put(dtl, dtl.delegatedValue(entry.getValue()));
        }
        // install only after every delegatedValue succeeded, so a failure
        // leaves the original map in place
        threadLocalMap.set(newmap);
        return oldmap;
    }

    /**
     * Used by ThreadContext. Notifies {@link #onUndelegated} for each value
     * in the calling thread's current map, then reinstates {@code oldmap}.
     * The original map is reinstated even if a callback throws.
     *
     * @param oldmap the map returned by the matching {@code delegate} call
     */
    static void restore(Map oldmap) {
        Map map = (Map)threadLocalMap.get();
        try {
            for (Iterator itr=map.entrySet().iterator(); itr.hasNext();) {
                Map.Entry entry = (Map.Entry)itr.next();
                DelegatableThreadLocal dtl = (DelegatableThreadLocal)entry.getKey();
                dtl.onUndelegated(entry.getValue());
            }
        } finally {
            // never leave the thread running with delegated values, even if
            // a subclass callback fails
            threadLocalMap.set(oldmap);
        }
    }
}
